#以下所有函数如无特别说明，输入参数S均为numpy序列或者列表list，N为整型int
#应用层1级函数完美兼容通达信或同花顺，具体使用方法请参考通达信

import numpy as np
import pandas as pd
import math
import copy

#------------------ 0级：核心工具函数 --------------------------------------------      
def RD(N,D=3):   return np.round(N,D)        #四舍五入取3位小数 
def RET(S,N=1):  return np.array(S)[-N]      #返回序列倒数第N个值,默认返回最后一个，用于取大型，最后一i谈
def ABS(S):      return np.abs(S)            #返回N的绝对值
def LN(S):       return np.log(S)            #求底是e的自然对数,
def POW(S,N):    return np.power(S,N)        #求S的N次方
def SQRT(S):     return np.sqrt(S)           #求S的平方根
def MAX(S1,S2):  return np.maximum(S1,S2)    #序列max
def MAX3(S1,S2,S3): return np.maximum(S1,S2,S3)  # 3列中的最大值，序列不能计算
def MAX4(S1,S2,S3,S4):  return  np.maximum(S1,S2,S3,S4)
def MIN(S1,S2):  return np.minimum(S1,S2)    #序列min
def IF(S,A,B):   return np.where(S,A,B)      #序列布尔判断 return=A  if S==True  else  B
def ArcTN(S):      return  np.arctan(S)        #求 S的反正切序列

def REF(S, N=1):          #对序列整体下移动N,返回序列(shift后会产生NAN)    
    return pd.Series(S).shift(N).values

def REFS(S,N=1):
    return pd.Series(S).shift(N)

def DIFF(S, N=1):         #前一个值减后一个值,前面会产生nan 
    return pd.Series(S).diff(N).values     #np.diff(S)直接删除nan，会少一行

def STD(S,N):             #求序列的N日标准差，返回序列    
    return  pd.Series(S).rolling(N).std(ddof=0).values     

def SUM(S, N):            #对序列求N天累计和，返回序列    N=0对序列所有依次求和         
    return pd.Series(S).rolling(N).sum().values if N>0 else pd.Series(S).cumsum().values  

def CONST(S):             #返回序列S最后的值组成常量序列
    return np.full(len(S),S[-1])
  
def HHV(S,N):             #HHV(C, 5) 最近5天收盘最高价        
    return pd.Series(S).rolling(N).max().values     

def LLV(S,N):             #LLV(C, 5) 最近5天收盘最低价     
    return pd.Series(S).rolling(N).min().values

def PctChg(C,S,N=1): # 计算 当日与前面几天的 涨幅,第一个参数要是 收盘价，第二个参数，可以是最高，最低，和收盘价
    return RD( ( S- REF(C,N)) /(REF(C,N)) * 100)

def PctChg2(C,S,N=1): # 计算今天的与前面几天的 涨幅,第一个参数要是 收盘价，第二个参数，可以是最高，最低，和收盘价
    try:
        rsult= (S[-1] - C[-N-1]) / (C[-N-1]) * 100
        return RD(rsult )
    except Exception as e:
        return None

def Pct_future_Chg(C,S,N=-1): # 计算今天的与前面几天的 涨幅,第一个参数要是 收盘价，第二个参数，可以是最高，最低，和收盘价
    return RD( ( REF(S,N)-C) /C * 100)

def maxPctChg(H,C, N=5):  # 计算截至今天与前面几天内最大值的最大涨幅，往往给H值。
    max_inN=HHV(H,N)
    return RD((max_inN-REF(C,N))/(REF(C,N))*100)

def minPctChg(L,C, N=5):  # 计算截至今天与前面几天，最小值的涨幅，也就是最大跌幅，往往给 Low值
    min_inN=LLV(L,N)
    return RD((min_inN)-REF(C,N))/(REF(C,N)*100)

def max_futurePctChg(H,C,N=-5):
    max_infueN=HHV(H,N)
    return RD((max_infueN-C)/C*100)

def min_futurePctChg(L,C,N=-5):
    min_infueN=LLV(L,N)
    return RD((min_infueN-C)/C*100)

def HHVBARS(S,N):         #求N周期内S最高值到当前周期数, 返回序列
    return pd.Series(S).rolling(N).apply(lambda x: np.argmax(x[::-1]),raw=True).values 

def LLVBARS(S,N):         #求N周期内S最低值到当前周期数, 返回序列
    return pd.Series(S).rolling(N).apply(lambda x: np.argmin(x[::-1]),raw=True).values    
  
def MA(S,N):              #求序列的N日简单移动平均值，返回序列                    
    return pd.Series(S).rolling(N).mean().values  
  
def EMA(S,N):             #指数移动平均,为了精度 S>4*N  EMA至少需要120周期     alpha=2/(span+1)       这个还用
    return pd.Series(S).ewm(span=N, adjust=False).mean().values

def EXP(S):             #指数移动平均,为了精度 S>4*N  EMA至少需要120周期     alpha=2/(span+1)       这个还用
    return np.exp(S).values

def SMA(S, N, M=1):       #中国式的SMA,至少需要120周期才精确 (雪球180周期)    alpha=1/(1+com)     这个好用
    return pd.Series(S).ewm(alpha=M/N,adjust=False).mean().values           #com=N-M/M

def MEMA(S,N):         #通达信中的 MEMA 就是SMA的平替
    return SMA(S,N,1)
def WMA(S, N):            #通达信S序列的N日加权移动平均 Yn = (1*X1+2*X2+3*X3+...+n*Xn)/(1+2+3+...+Xn)   这个好用
    return pd.Series(S).rolling(N).apply(lambda x:x[::-1].cumsum().sum()*2/N/(N+1),raw=True).values 

def DMA(S, A):            #求S的动态移动平均，A作平滑因子,必须 0<A<1  (此为核心函数，非指标）    这个好用 ，A 用换手率，换手率越大，当日权重越大
    if isinstance(A,(int,float)):  return pd.Series(S).ewm(alpha=A,adjust=False).mean().values    
    A=np.array(A);   A[np.isnan(A)]=1.0;   Y= np.zeros(len(S));   Y[0]=S[0]     
    for i in range(1,len(S)): Y[i]=A[i]*S[i]+(1-A[i])*Y[i-1]      #A支持序列 by jqz1226         
    return Y             
  
def AVEDEV(S, N):         #平均绝对偏差  (序列与其平均值的绝对差的平均值)   
    return pd.Series(S).rolling(N).apply(lambda x: (np.abs(x - x.mean())).mean()).values 

def SLOPE(S, N):          #返S序列N周期回线性回归斜率              这个 更好，可以设想到，10日线性回归斜率，向下，就不买了，
    return pd.Series(S).rolling(N).apply(lambda x: np.polyfit(range(N),x,deg=1)[0],raw=True).values

def FORCAST(S, N):        #返回S序列N周期回线性回归后的预测值， jqz1226改进成序列出    
    return pd.Series(S).rolling(N).apply(lambda x:np.polyval(np.polyfit(range(N),x,deg=1),N-1),raw=True).values  

def LAST(S, A, B):        #从前A日到前B日一直满足S_BOOL条件, 要求A>B & A>0 & B>=0    前20天到前5天， 一直满足 斜率大于-1 小于0.2
    return np.array(pd.Series(S).rolling(A+1).apply(lambda x:np.all(x[::-1][B:]),raw=True),dtype=bool)
  
#------------------   1级：应用层函数(通过0级核心函数实现）使用方法请参考通达信--------------------------------
def COUNT(S, N):                       # COUNT(CLOSE>O, N):  最近N天满足S_BOO的天数  True的天数
    return SUM(S,N)    

def EVERY(S, N):                       # EVERY(CLOSE>O, 5)   最近N天是否都是True，   3天内，都在20日均线上，3天内都在通道向上
    return  IF(SUM(S,N)==N,True,False)                    
  
def EXIST(S, N):                       # EXIST(CLOSE>3010, N=5)  n日内是否存在一天大于3000点  
    return IF(SUM(S,N)>0,True,False)

def EXSITER(S,N,M):    # 例如 N=70 ,M=5  测试 70天前到 5天前的 条件是否都成立
    aa = COUNT(S, N)
    bb = COUNT(S, M)
    cc = aa - bb
    return IF (cc > 0, True, False)


def FILTER(S, N):                      # FILTER函数，S满足条件后，将其后N周期内的数据置为0, FILTER(C==H,5)
    for i in range(len(S)): S[i+1:i+1+N]=0  if S[i] else S[i+1:i+1+N]        
    return S                           # 例：FILTER(C==H,5) 涨停后，后5天不再发出信号 
  
def BARSLAST(S):                       #上一次条件成立到当前的周期, BARSLAST(C/REF(C,1)>=1.1) 上一次涨停到今天的天数 
    M=np.concatenate(([0],np.where(S,1,0)))  
    for i in range(1, len(M)):  M[i]=0 if M[i] else M[i-1]+1    
    return M[1:]                       

def BARSLASTCOUNT(S):                  # 统计连续满足S条件的周期数        by jqz1226
    rt = np.zeros(len(S)+1)            # BARSLASTCOUNT(CLOSE>OPEN)表示统计连续收阳的周期数
    for i in range(len(S)): rt[i+1]=rt[i]+1  if S[i] else rt[i+1]
    return rt[1:]  
  
def BARSSINCEN(S, N):                  # N周期内第一次S条件成立到现在的周期数,N为常量  by jqz1226
    return pd.Series(S).rolling(N).apply(lambda x:N-1-np.argmax(x) if np.argmax(x) or x[0] else 0,raw=True).fillna(0).values.astype(int)
  
def CROSS(S1, S2):                     # 判断向上金叉穿越 CROSS(MA(C,5),MA(C,10))  判断向下死叉穿越 CROSS(MA(C,10),MA(C,5))   
    return np.concatenate(([False], np.logical_not((S1>S2)[:-1]) & (S1>S2)[1:]))    # 不使用0级函数,移植方便  by jqz1226
    
def LONGCROSS(S1,S2,N):                # 两条线维持一定周期后交叉,S1在N周期内都小于S2,本周期从S1下方向上穿过S2时返回1,否则返回0         
    return  np.array(np.logical_and(LAST(S1<S2,N,1),(S1>S2)),dtype=bool)            # N=1时等同于CROSS(S1, S2)
    
def VALUEWHEN(S, X):                   # 当S条件成立时,取X的当前值,否则取VALUEWHEN的上个成立时的X值   by jqz1226
    return pd.Series(np.where(S,X,np.nan)).ffill().values  

def BETWEEN(S, A, B):                  # S处于A和B之间时为真。 等于通达信的RANGE,包括 A<S<B 或 A>S>B， 用于判断指标指， 是否在某个范围内，就是安全的
    return ((A<S) & (S<B)) | ((A>S) & (S>B))  

def TOPRANGE(S):                       # TOPRANGE(HIGH)表示当前最高价是近多少周期内最高价的最大值 by jqz1226
    rt = np.zeros(len(S))
    for i in range(1,len(S)):  rt[i] = np.argmin(np.flipud(S[:i]<S[i]))
    return rt.astype('int')

def LOWRANGE(S):                       # LOWRANGE(LOW)表示当前最低价是近多少周期内最低价的最小值 by jqz1226
    rt = np.zeros(len(S))
    for i in range(1,len(S)):  rt[i] = np.argmin(np.flipud(S[:i]>S[i]))
    return rt.astype('int')
  
  
#------------------   2级：技术指标函数(全部通过0级，1级函数实现） ------------------------------
def MACD(CLOSE,SHORT=12,LONG=26,M=9):             # EMA的关系，S取120日，和雪球小数点2位相同
    DIF = EMA(CLOSE,SHORT)-EMA(CLOSE,LONG);  
    DEA = EMA(DIF,M);      MACD=(DIF-DEA)*2
    return RD(DIF),RD(DEA),RD(MACD)

def VMACD(VOL,SHORT=12,LONG=26,MID=9):
    DIF= EMA(VOL, SHORT) - EMA(VOL, LONG)
    DEA= EMA(DIF, MID)
    MACD= DIF - DEA
    return RD(DIF),RD(DEA),RD(MACD)

def KDJ(CLOSE,HIGH,LOW, N=9,M1=3,M2=3):           # KDJ指标
    RSV = (CLOSE - LLV(LOW, N)) / (HHV(HIGH, N) - LLV(LOW, N)) * 100
    K = EMA(RSV, (M1*2-1));    D = EMA(K,(M2*2-1));        J=K*3-D*2
    return K, D, J

def SKDJ(CLOSE,HIGH,LOW,N=9,M=3): #慢速K，D 策略， 结合周线KDJ 向上，日线kd 金叉，信号更稳定。
    LOWV= LLV(LOW, N);HIGHV = HHV(HIGH, N);
    RSV = EMA((CLOSE - LOWV) / (HIGHV - LOWV) * 100, M)
    K= EMA(RSV, M);    D= MA(K, M)
    return K,D


def RSI(CLOSE, N=24):                           # RSI指标,和通达信小数点2位相同
    DIF = CLOSE-REF(CLOSE,1) 
    return RD(SMA(MAX(DIF,0), N) / SMA(ABS(DIF), N) * 100)  

def WR(CLOSE, HIGH, LOW, N=10, N1=6):            #W&R 威廉指标
    WR = (HHV(HIGH, N) - CLOSE) / (HHV(HIGH, N) - LLV(LOW, N)) * 100
    WR1 = (HHV(HIGH, N1) - CLOSE) / (HHV(HIGH, N1) - LLV(LOW, N1)) * 100
    return RD(WR), RD(WR1)

def CR(CLOSE,HIGH,LOW,N=20):                        #CR价格动量指标
    MID=REF(HIGH+LOW+CLOSE,1)/3;
    return SUM(MAX(0,HIGH-MID),N)/SUM(MAX(0,MID-LOW),N)*100

def BIAS(CLOSE,L1=6, L2=12, L3=24):              # BIAS乖离率
    BIAS1 = (CLOSE - MA(CLOSE, L1)) / MA(CLOSE, L1) * 100
    BIAS2 = (CLOSE - MA(CLOSE, L2)) / MA(CLOSE, L2) * 100
    BIAS3 = (CLOSE - MA(CLOSE, L3)) / MA(CLOSE, L3) * 100
    return RD(BIAS1), RD(BIAS2), RD(BIAS3)

def BOLL(CLOSE,N=20, P=2):                       #BOLL指标，布林带    
    MID = MA(CLOSE, N); 
    UPPER = MID + STD(CLOSE, N) * P
    LOWER = MID - STD(CLOSE, N) * P
    return RD(UPPER), RD(MID), RD(LOWER)    

def PSY(CLOSE,N=12, M=6):  
    PSY=COUNT(CLOSE>REF(CLOSE,1),N)/N*100
    PSYMA=MA(PSY,M)
    return RD(PSY),RD(PSYMA)

def CCI(CLOSE,HIGH,LOW,N=84):  # CCI 做长线线，中线，N设置为 84，也就是4个月，一个季度，就能完成减仓
    TP=(HIGH+LOW+CLOSE)/3
    return (TP-MA(TP,N))/(0.015*AVEDEV(TP,N))
        
def ATR(CLOSE,HIGH,LOW, N=20):                    #真实波动N日平均值
    TR = MAX(MAX((HIGH - LOW), ABS(REF(CLOSE, 1) - HIGH)), ABS(REF(CLOSE, 1) - LOW))
    return MA(TR, N)

def BBI(CLOSE,M1=3,M2=6,M3=12,M4=20):             #BBI多空指标   
    return (MA(CLOSE,M1)+MA(CLOSE,M2)+MA(CLOSE,M3)+MA(CLOSE,M4))/4    

def DMI(CLOSE,HIGH,LOW,M1=14,M2=6):               #动向指标：结果和同花顺，通达信完全一致  短线，信号交叉和整体趋势判断结合
    TR = SUM(MAX(MAX(HIGH - LOW, ABS(HIGH - REF(CLOSE, 1))), ABS(LOW - REF(CLOSE, 1))), M1)
    HD = HIGH - REF(HIGH, 1);     LD = REF(LOW, 1) - LOW
    DMP = SUM(IF((HD > 0) & (HD > LD), HD, 0), M1)
    DMM = SUM(IF((LD > 0) & (LD > HD), LD, 0), M1)
    PDI = DMP * 100 / TR;         MDI = DMM * 100 / TR
    ADX = MA(ABS(MDI - PDI) / (PDI + MDI) * 100, M2)
    ADXR = (ADX + REF(ADX, M2)) / 2
    return PDI, MDI, ADX, ADXR  

def TAQ(HIGH,LOW,N):                               #唐安奇通道(海龟)交易指标，大道至简，能穿越牛熊，信号太差，没软用
    UP=HHV(HIGH,N);    DOWN=LLV(LOW,N);    MID=(UP+DOWN)/2
    return UP,MID,DOWN

def KTN(CLOSE,HIGH,LOW,N=20,M=10):                 #肯特纳交易通道, N选20日，ATR选10日
    MID=EMA((HIGH+LOW+CLOSE)/3,N)
    ATRN=ATR(CLOSE,HIGH,LOW,M)
    UPPER=MID+2*ATRN;   LOWER=MID-2*ATRN
    return UPPER,MID,LOWER       
  
def TRIX(CLOSE,M1=12, M2=20):                      #三重指数平滑平均线
    TR = EMA(EMA(EMA(CLOSE, M1), M1), M1)
    TRIX = (TR - REF(TR, 1)) / REF(TR, 1) * 100
    TRMA = MA(TRIX, M2)
    return TRIX, TRMA

def VR(CLOSE,VOL,M1=26):                            #VR容量比率
    LC = REF(CLOSE, 1)
    return SUM(IF(CLOSE > LC, VOL, 0), M1) / SUM(IF(CLOSE <= LC, VOL, 0), M1) * 100

def EMV(HIGH,LOW,VOL,N=14,M=9):                     #简易波动指标 
    VOLUME=MA(VOL,N)/VOL;       MID=100*(HIGH+LOW-REF(HIGH+LOW,1))/(HIGH+LOW)
    EMV=MA(MID*VOLUME*(HIGH-LOW)/MA(HIGH-LOW,N),N);    MAEMV=MA(EMV,M)
    return EMV,MAEMV


def DPO(CLOSE,M1=20, M2=10, M3=6):                  #区间震荡线
    DPO = CLOSE - REF(MA(CLOSE, M1), M2);    MADPO = MA(DPO, M3)
    return DPO, MADPO

def BRAR(OPEN,CLOSE,HIGH,LOW,M1=26):                 #BRAR-ARBR 情绪指标  
    AR = SUM(HIGH - OPEN, M1) / SUM(OPEN - LOW, M1) * 100
    BR = SUM(MAX(0, HIGH - REF(CLOSE, 1)), M1) / SUM(MAX(0, REF(CLOSE, 1) - LOW), M1) * 100
    return AR, BR

def DFMA(CLOSE,N1=10,N2=50,M=10):                    #平行线差指标           这个值得一看呀，波段用 可以结合其他指标，形成共振
    DIF=MA(CLOSE,N1)-MA(CLOSE,N2); DIFMA=MA(DIF,M)   #通达信指标叫DMA 同花顺叫新DMA
    return DIF,DIFMA

def MTM(CLOSE,N=12,M=6):                             #动量指标
    MTM=CLOSE-REF(CLOSE,N);         MTMMA=MA(MTM,M)
    return MTM,MTMMA

def MASS(HIGH,LOW,N1=9,N2=25,M=6):                   #梅斯线
    MASS=SUM(MA(HIGH-LOW,N1)/MA(MA(HIGH-LOW,N1),N1),N2)
    MA_MASS=MA(MASS,M)
    return MASS,MA_MASS
  
def ROC(CLOSE,N=12,M=6):                             #变动率指标
    ROC=100*(CLOSE-REF(CLOSE,N))/REF(CLOSE,N);    MAROC=MA(ROC,M)
    return ROC,MAROC  

def EXPMA(CLOSE,N1=12,N2=50):                        #EMA指数平均数指标
    return EMA(CLOSE,N1),EMA(CLOSE,N2);

def OBV(CLOSE,VOL):                                  #能量潮指标
    return SUM(IF(CLOSE>REF(CLOSE,1),VOL,IF(CLOSE<REF(CLOSE,1),-VOL,0)),0)/10000

def MFI(CLOSE,HIGH,LOW,VOL,N=14):                    #MFI指标是成交量的RSI指标
    TYP = (HIGH + LOW + CLOSE)/3
    V1=SUM(IF(TYP>REF(TYP,1),TYP*VOL,0),N)/SUM(IF(TYP<REF(TYP,1),TYP*VOL,0),N)  
    return 100-(100/(1+V1))     
  
def ASI(OPEN,CLOSE,HIGH,LOW,M1=26,M2=10):            #振动升降指标
    LC=REF(CLOSE,1);      AA=ABS(HIGH-LC);     BB=ABS(LOW-LC);
    CC=ABS(HIGH-REF(LOW,1));   DD=ABS(LC-REF(OPEN,1));
    R=IF( (AA>BB) & (AA>CC),AA+BB/2+DD/4,IF( (BB>CC) & (BB>AA),BB+AA/2+DD/4,CC+DD/4));
    X=(CLOSE-LC+(CLOSE-OPEN)/2+LC-REF(OPEN,1));
    SI=16*X/R*MAX(AA,BB);   ASI=SUM(SI,M1);   ASIT=MA(ASI,M2);
    return ASI,ASIT   

def XSII(CLOSE, HIGH, LOW, N=102, M=7):              #薛斯通道II      可以用于超短线，早盘追涨，判断通道向上的。 才行
    AA  = MA((2*CLOSE + HIGH + LOW)/4, 5)            #最新版DMA才支持 2021-12-4
    TD1 = AA*N/100;   TD2 = AA*(200-N) / 100
    CC =  ABS((2*CLOSE + HIGH + LOW)/4 - MA(CLOSE,20))/MA(CLOSE,20)
    DD =  DMA(CLOSE,CC);    TD3=(1+M/100)*DD;      TD4=(1-M/100)*DD
    return TD1, TD2, TD3, TD4
#
def my_HHV(S, S2):  # HHV,支持N为序列版本
        # type: (np.ndarray, Optional[int,float, np.ndarray]) -> np.ndarray
        """
        HHV(C, 5)  # 最近5天收盘最高价
        """
        if isinstance(S2, (int, float)):
            return pd.Series(S).rolling(S2).max().values
        else:
            res = np.repeat(np.nan, len(S))
            for i in range(len(S)):
                if (not np.isnan(S2[i])) and S2[i] <= i + 1:
                    res[i] = S[i + 1 - S2[i]:i + 1].max()
            return res

# def LLV(S, N):  # LLV,支持N为序列版本
#         # type: (np.ndarray, Optional[int,float, np.ndarray]) -> np.ndarray
#         """
#         LLV(C, 5)  # 最近5天收盘最低价
#         """
#         if isinstance(N, (int, float)):
#             return pd.Series(S).rolling(N).min().values
#         else:
#             res = np.repeat(np.nan, len(S))
#             for i in range(len(S)):
#                 if (not np.isnan(N[i])) and N[i] <= i + 1:
#                     res[i] = S[i + 1 - N[i]:i + 1].min()
#             return res

###########################第一个参数是 序列，第二个参数也是序列，用于回测#########################
def my_LLVBARS(tmp_a, tmp_b):
    rt_arr = []
    for i in range(0, len(tmp_b)):
        # tmp_num = tmp_b[i]
        temp_num=int(tmp_b[i]  )         # 序列值
        if((i+1-temp_num)>0):
            start_idx=i+1-temp_num
        else:
            start_idx=0

        tmp_arr = tmp_a[start_idx: i + 1]
        wz = 0
        if len(tmp_arr) > 0:
            for j in range(0, len(tmp_arr)):
                if tmp_arr[j] == min(tmp_arr):
                    wz = len(tmp_arr) - j - 1
                    rt_arr.append(wz)
                    break
        else:
            rt_arr.append(0)
    return pd.Series(rt_arr)

def my_HHVBARS(tmp_a, tmp_b):
    rt_arr = []
    for i in range(0, len(tmp_b)):
        # tmp_num = tmp_b[i]
        temp_num = int(tmp_b[i] ) # 序列值
        if ((i + 1 - temp_num) > 0):
            start_idx = i + 1 - temp_num
        else:
            start_idx = 0

        tmp_arr = tmp_a[start_idx: i + 1]
        wz = 0
        if len(tmp_arr) > 0:
            for j in range(0, len(tmp_arr)):
                if tmp_arr[j] == max(tmp_arr):
                    wz = len(tmp_arr) - j - 1
                    rt_arr.append(wz)
                    break
        else:
            rt_arr.append(0)
    return pd.Series(rt_arr)


def my_EVERY(S, S2):  # EVERY(CLOSE>O, 5)   最近N天是否都是True，   3天内，都在20日均线上，3天内都在通道向上
    rt_arr = []
    for i in range(0, len(S2)):
        tmp_num = S2[i]
        tmp_arr = S[i + 1 - tmp_num: i + 1]
        wz = 0
        if len(tmp_arr) > 0:
            for j in range(0, len(tmp_arr)):
                if tmp_arr[j] == True:
                    wz = wz+1
            if wz==len(tmp_arr):
                rt_arr.append(True)
            else:
                rt_arr.append(False)
        else:
            rt_arr.append(False)
    return pd.Series(rt_arr).values


def my_REF(close_a,series_b):
    rt_arr = []
    for i in range(0, len(series_b)):
        aa = int(series_b[i])
        tmp_arr = close_a[:i + 1]
        if aa>=len(tmp_arr):
            bb= 0
        else:
            bb = tmp_arr[-aa - 1]
        rt_arr.append(bb)
    return pd.Series(rt_arr).values


def my_LLV(tmp_a, tmp_b):
    rt_arr = []
    for i in range(0, len(tmp_b)):
        tmp_num = tmp_b[i]
        tmp_arr = tmp_a[i + 1 - tmp_num: i + 1]
        if len(tmp_arr) > 0:
            aaa = min(tmp_arr)
            rt_arr.append(aaa)
        else:
            rt_arr.append(0)
    return pd.Series(rt_arr).values

# 动态序列求和
def my_SUM(S1,S2):
    rt_arr=[]
    for i in range(0,len(S2)):  # 动态序列
        temp_num=S2[i]           # 序列值
        if((i+1-temp_num)>0):
            start_idx=i+1-temp_num
        else:
            start_idx=0
        #   起始位置
        temp_arr=S1[start_idx:i+1]   # 动态小数组
        if len(temp_arr) > 0:
            aaa = sum(temp_arr)
            rt_arr.append(aaa)
        else:
            rt_arr.append(S1[i])
    return pd.Series(rt_arr).values

def my_COUNT(S1,S2):
    SS1= IF(S1==True,1,0)
    rt_arr=[]
    for i in range(0,len(S2)):  # 动态序列
        temp_num=S2[i]           # 序列值
        if((i+1-temp_num)>0):
            start_idx=i+1-temp_num
        else:
            start_idx=0
        #   起始位置
        temp_arr=SS1[start_idx:i+1]   # 动态小数组
        if len(temp_arr) > 0:
            aaa = sum(temp_arr)
            rt_arr.append(aaa)
        else:
            rt_arr.append(SS1[i])
    return pd.Series(rt_arr).values


def tdx_gs_tmp(tmp_data):
    h = tmp_data
    # print(h)
    CLOSE = h['close'].values
    HIGH = h['high'].values
    LOW = h['low'].values
    OPEN = h['open'].values
    VOL = h['volume'].values
    AMOUNT = h['money'].values

    # 长期均线与短期均线
    MA5 = MA(CLOSE, 5)  # 表示5日均线
    MA10 = MA(CLOSE, 10)  # 表示10日均线
    MA200 = MA(CLOSE, 200)  # 表示200日长期均线
    DIFF = EMA(CLOSE, 12) - EMA(CLOSE, 26)
    DEA = EMA(DIFF, 9)
    MACD = 2 * (DIFF - DEA)
    HAHA = REF(MACD, 1) < 0
    HEHE = MACD > 0
    WW = HAHA & HEHE
    QZQ = BARSLAST(WW == True)
    QZQ2 = QZQ + 20
    QM = my_LLVBARS(MACD, QZQ2)
    tmp_MQDZ = my_REF(MACD, QM)
    MQDZ = IF(QM > QZQ, tmp_MQDZ, 0)
    aa1 = (MACD < 0)
    aa2 = (MACD > MQDZ)
    tmp_LLV = my_LLV(CLOSE, QZQ2)
    aa3 = (CLOSE == tmp_LLV)
    aa4 = (MQDZ < 0)
    aa5 = (MA200 > REF(MA200, 1))
    XG1 = aa1 & aa2 & aa3 & aa4 & aa5
    aa6 = REF(XG1, 1) == True
    aa7 = REF(MACD, 1) < MACD
    aa8 = MA200 > REF(MA200, 1)
    XG = aa6 & aa7 & aa8
    h['XG'] = XG
    return h


def DSMA(X, N):  # 偏差自适应移动平均线   type: (np.ndarray, int) -> np.ndarray
        """
        Deviation Scaled Moving Average (DSMA)    Python by: jqz1226, 2021-12-27
        Referred function from myTT: SUM, DMA
        """
        a1 = math.exp(- 1.414 * math.pi * 2 / N)
        b1 = 2 * a1 * math.cos(1.414 * math.pi * 2 / N)
        c2 = b1
        c3 = -a1 * a1
        c1 = 1 - c2 - c3
        Zeros = np.pad(X[2:] - X[:-2], (2, 0), 'constant')
        Filt = np.zeros(len(X))
        for i in range(len(X)):
            Filt[i] = c1 * (Zeros[i] + Zeros[i - 1]) / 2 + c2 * Filt[i - 1] + c3 * Filt[i - 2]

        RMS = np.sqrt(SUM(np.square(Filt), N) / N)
        ScaledFilt = Filt / RMS
        alpha1 = np.abs(ScaledFilt) * 5 / N
        return DMA(X, alpha1)

def SUMBARSFAST(X, A):
        # type: (np.ndarray, Optional[np.ndarray, float, int]) -> np.ndarray
        """
        通达信SumBars函数的Python实现  by jqz1226
        SumBars函数将X向前累加，直到大于等于A, 返回这个区间的周期数。例如SUMBARS(VOL, CAPITAL),求完全换手的周期数。
        :param X: 数组。被累计的源数据。 源数组中不能有小于0的元素。
        :param A: 数组（一组）或者浮点数（一个）或者整数（一个），累加截止的界限数
        :return:  数组。各K线分别对应的周期数
        """
        if any(X <= 0):   raise ValueError('数组X的每个元素都必须大于0！')

        X = np.flipud(X)  # 倒转
        length = len(X)

        if isinstance(A * 1.0, float):  A = np.repeat(A, length)  # 是单值则转化为数组
        A = np.flipud(A)  # 倒转
        sumbars = np.zeros(length)  # 初始化sumbars为0
        Sigma = np.insert(np.cumsum(X), 0, 0.0)  # 在累加值前面插入一个0.0（元素变多1个，便于引用）

        for i in range(length):
            k = np.searchsorted(Sigma[i + 1:], A[i] + Sigma[i])
            if k < length - i:  # 找到
                sumbars[length - i - 1] = k + 1
        return sumbars.astype(int)

        # ------------------------指标函数---------------------------------------------

def SAR(HIGH, LOW, N=10, S=2, M=20):
        """
        求抛物转向。 例如SAR(10,2,20)表示计算10日抛物转向，步长为2%，步长极限为20%
        Created by: jqz1226, 2021-11-24首次发表于聚宽(www.joinquant.com)

        :param HIGH: high序列
        :param LOW: low序列
        :param N: 计算周期
        :param S: 步长
        :param M: 步长极限
        :return: 抛物转向
        """
        f_step = S / 100;
        f_max = M / 100;
        af = 0.0
        is_long = HIGH[N - 1] > HIGH[N - 2]
        b_first = True
        length = len(HIGH)

        s_hhv = REF(HHV(HIGH, N), 1)  # type: np.ndarray
        s_llv = REF(LLV(LOW, N), 1)  # type: np.ndarray
        sar_x = np.repeat(np.nan, length)  # type: np.ndarray
        for i in range(N, length):
            if b_first:  # 第一步
                af = f_step
                sar_x[i] = s_llv[i] if is_long else s_hhv[i]
                b_first = False
            else:  # 继续多 或者 空
                ep = s_hhv[i] if is_long else s_llv[i]  # 极值
                if (is_long and HIGH[i] > ep) or ((not is_long) and LOW[i] < ep):  # 顺势：多创新高 或者 空创新低
                    af = min(af + f_step, f_max)
                #
                sar_x[i] = sar_x[i - 1] + af * (ep - sar_x[i - 1])

            if (is_long and LOW[i] < sar_x[i]) or ((not is_long) and HIGH[i] > sar_x[i]):  # 反空 或者 反多
                is_long = not is_long
                b_first = True
        return sar_x

def TDX_SAR(High, Low, iAFStep=2, iAFLimit=20):  # type: (np.ndarray, np.ndarray, int, int) -> np.ndarray
        """  通达信SAR算法,和通达信SAR对比完全一致   by: jqz1226, 2021-12-18
        :param High: 最高价序列
        :param Low: 最低价序列
        :param iAFStep: AF步长
        :param iAFLimit: AF极限值
        :return: SAR序列
        """
        af_step = iAFStep / 100;
        af_limit = iAFLimit / 100
        SarX = np.zeros(len(High))  # 初始化返回数组

        # 第一个bar
        bull = True
        af = af_step
        ep = High[0]
        SarX[0] = Low[0]
        # 第2个bar及其以后
        for i in range(1, len(High)):
            # 1.更新：hv, lv, af, ep
            if bull:  # 多
                if High[i] > ep:  # 创新高
                    ep = High[i]
                    af = min(af + af_step, af_limit)
            else:  # 空
                if Low[i] < ep:  # 创新低
                    ep = Low[i]
                    af = min(af + af_step, af_limit)
            # 2.计算SarX
            SarX[i] = SarX[i - 1] + af * (ep - SarX[i - 1])

            # 3.修正SarX
            if bull:
                SarX[i] = max(SarX[i - 1], min(SarX[i], Low[i], Low[i - 1]))
            else:
                SarX[i] = min(SarX[i - 1], max(SarX[i], High[i], High[i - 1]))

            # 4. 判断是否：向下跌破，向上突破
            if bull:  # 多
                if Low[i] < SarX[i]:  # 向下跌破，转空
                    bull = False
                    tmp_SarX = ep  # 上阶段的最高点
                    ep = Low[i]
                    af = af_step
                    if High[i - 1] == tmp_SarX:  # 紧邻即最高点
                        SarX[i] = tmp_SarX
                    else:
                        SarX[i] = tmp_SarX + af * (ep - tmp_SarX)
            else:  # 空
                if High[i] > SarX[i]:  # 向上突破, 转多
                    bull = True
                    ep = High[i]
                    af = af_step
                    SarX[i] = min(Low[i], Low[i - 1])
        # end for
        return SarX

#望大家能提交更多指标和函数  https://github.com/mpquant/MyTT
class ChipDistribution():

    def __init__(self,trade_date,turnover_rate,volume,open,high,low,close,avg):
        self.Chip = {} # 当前获利盘
        self.ChipList = {}  # 所有的获利盘的
        self.trade_date=trade_date
        self.turnover_rate=turnover_rate
        self.volume=volume
        self.open=open
        self.high=high
        self.low=low
        self.close=close
        self.avg=avg
    #
    # def get_data(self):
    #     # trade_date, turnover_rate, amount, volume,open,high,low,close, avg_price=  amount /vol/100
    #     self.data = pd.read_csv('test.csv')
    #
     #均值法 计算筹码
    def calcuJUN(self,dateT,highT, lowT, volT, TurnoverRateT, A, minD):

        x =[]
        l = (highT - lowT) / minD
        for i in range(int(l)):
            x.append(round(lowT + i * minD, 2))
        length = len(x)
        eachV = volT/length
        for i in self.Chip:
            self.Chip[i] = self.Chip[i] *(1 -TurnoverRateT * A)
        for i in x:
            if i in self.Chip:
                self.Chip[i] += eachV *(TurnoverRateT * A)
            else:
                self.Chip[i] = eachV *(TurnoverRateT * A)
        import copy
        self.ChipList[dateT] = copy.deepcopy(self.Chip)


# 三角函数发计算筹码，默认三角函数法
    def calcuSin(self,dateT,highT, lowT,avgT, volT,TurnoverRateT,minD,A):
        x =[]
        l = (highT - lowT) / minD
        for i in range(int(l)):
            x.append(round(lowT + i * minD, 2))
        length = len(x)
        #计算仅仅今日的筹码分布
        tmpChip = {}
        # eachV = volT/length
        #极限法分割去逼近
        for i in x:
            x1 = i
            x2 = i + minD
            h = 2 / (highT - lowT)
            s= 0
            if i < avgT:
                y1 = h /(avgT - lowT) * (x1 - lowT)
                y2 = h /(avgT - lowT) * (x2 - lowT)
                s = minD *(y1 + y2) /2
                s = s * volT
            else:
                y1 = h /(highT - avgT) *(highT - x1)
                y2 = h /(highT - avgT) *(highT - x2)

                s = minD *(y1 + y2) /2
                s = s * volT
            tmpChip[i] = s


        for i in self.Chip:
            self.Chip[i] = self.Chip[i] *(1 -TurnoverRateT * A)

        for i in tmpChip:
            if i in self.Chip:
                self.Chip[i] += tmpChip[i] *(TurnoverRateT * A)
            else:
                self.Chip[i] = tmpChip[i] *(TurnoverRateT * A)
        import copy
        self.ChipList[dateT] = copy.deepcopy(self.Chip)


    def calcu(self,dateT,highT, lowT,avgT, volT, TurnoverRateT,minD = 0.01, flag=1 , AC=1):
        if flag ==1:
            self.calcuSin(dateT,highT, lowT,avgT, volT, TurnoverRateT,A=AC, minD=minD)
        elif flag ==2:
            self.calcuJUN(dateT,highT, lowT, volT, TurnoverRateT, A=AC, minD=minD)

    def calcuChip(self, flag=1, AC=1):  #flag 使用哪个计算方式,    AC 衰减系数,受限要选择筹码计算方式，   如果不用default 计算 winner ,cost, 那么这个要先执行，
        low = self.low
        high = self.high
        vol = self.volume
        TurnoverRate = self.turnover_rate
        avg = self.avg
        date = self.trade_date

        for i in range(len(date)):
        #     if i < 90:
        #         continue

            highT = high[i]
            lowT = low[i]
            volT = vol[i]
            TurnoverRateT = TurnoverRate[i]
            avgT = avg[i]
            # print(date[i])
            dateT = date[i]
            self.calcu(dateT,highT, lowT,avgT, volT, TurnoverRateT/100, flag=flag, AC=AC)  # 东方财富的小数位要注意，兄弟萌。我不除100懵逼了

        # 计算winner, 传入参数是 价位
    def winner(self,p=None):
            Profit = []
            date = self.trade_date

            if p == None:  # 不输入默认close
                p = self.close
                count = 0
                for i in self.ChipList:
                    # 计算目前的比例

                    Chip = self.ChipList[i]
                    total = 0
                    be = 0
                    for i in Chip:
                        total += Chip[i]
                        if i < p[count]:
                            be += Chip[i]
                    if total != 0:
                        bili = be / total
                    else:
                        bili = 0
                    count += 1
                    Profit.append(bili)
            else:
                PP = self.close*p
                count = 0
                for i in self.ChipList:
                    # 计算目前的比例

                    Chip = self.ChipList[i]
                    total = 0
                    be = 0
                    for i in Chip:
                        total += Chip[i]
                        if i < PP[count]:
                            be += Chip[i]
                    if total != 0:
                        bili = be / total
                    else:
                        bili = 0
                    count += 1
                    Profit.append(bili)
            # import matplotlib.pyplot as plt
            # plt.plot(date[len(date) - 200:-1], Profit[len(date) - 200:-1])
            # plt.show()

            return Profit
 # p 是一个值
    def winner_default(self,p=None):
            self.calcuChip(flag=1, AC=1)
            Profit = []
            date = self.trade_date

            if p == None:  # 不输入默认close
                p = self.close
                count = 0
                for i in self.ChipList:
                    # 计算目前的比例

                    Chip = self.ChipList[i]
                    total = 0
                    be = 0
                    for i in Chip:
                        total += Chip[i]
                        if i < p[count]:
                            be += Chip[i]
                    if total != 0:
                        bili = be / total
                    else:
                        bili = 0
                    count += 1
                    Profit.append(bili)
            else:

                for i in self.ChipList:
                    # 计算目前的比例

                    Chip = self.ChipList[i]
                    total = 0
                    be = 0
                    for i in Chip:
                        total += Chip[i]
                        if i < p:
                            be += Chip[i]
                    if total != 0:
                        bili = be / total
                    else:
                        bili = 0
                    Profit.append(bili)

            # import matplotlib.pyplot as plt
            # plt.plot(date[len(date) - 200:-1], Profit[len(date) - 200:-1])
            # plt.show()

            return pd.Series(Profit).values
    def lwinner(self,N = 5, p=None):

         trade_date_copy = copy.deepcopy(self.trade_date)
         turnover_rate_copy = copy.deepcopy(self.turnover_rate)
         volume_copy = copy.deepcopy(self.volume)
         open_copy=copy.deepcopy(self.open )
         high_copy=copy.deepcopy(self.high )
         low_copy=copy.deepcopy(self.low )
         close_copy=copy.deepcopy(self.close )
         avg_copy=copy.deepcopy(self.avg )
         datet = self.trade_date
         ans = []
         for i in range(len(datet)):
            print(datet[i])
            if i < N:
                ans.append(None)
                continue
            open_i = open_copy[i-N:i]
            trade_date_i=trade_date_copy[i-N:i]
            avg_i=avg_copy[i-N:i]
            low_i=low_copy[i-N:i]
            close_i=close_copy[i-N:i]
            high_i=high_copy[i-N:i]
            volume_i=volume_copy[i-N:1]
            turnover_rate_i=turnover_rate_copy[i-N:1]

            # self.data.index= range(0,N)
            self.__init__(trade_date_i,turnover_rate_i,volume_i,open_i,high_i,low_i,close_i,avg_i)   #初始化，质控
            self.calcuChip()    #使用默认计算方式
            a = self.winner(p)
            ans.append(a[-1])
        # import matplotlib.pyplot as plt
        # plt.plot(date[len(date) - 60:-1], ans[len(date) - 60:-1])
        # plt.show()

         # self.data = data
         self.trade_date = trade_date_copy
         self.turnover_rate = turnover_rate_copy
         self.volume = volume_copy
         self.open = open_copy
         self.high = high_copy
         self.low = low_copy
         self.close = close_copy
         self.avg = avg_copy
         return ans


    # COST(10), 表示10 % 获利盘的价格是多少, 即有10 % 的持仓量在该价格以下, 其余90 % 在该价格以上, 为套牢盘，该函数仅对日线分析周期有效。
    # 这个也是根据筹码分布就可以简单计算，cost和winner就像两个相反的东西
    def cost(self,N):
        self.calcuChip(flag=1, AC=1)
        date = self.trade_date
        N = N / 100  # 转换成百分比
        ans = []
        for i in self.ChipList:  # 我的ChipList本身就是有顺序的
            Chip = self.ChipList[i]
            ChipKey = sorted(Chip.keys())  # 排序
            total = 0  # 当前比例
            sumOf = 0  # 所有筹码的总和
            if len(Chip) ==0:
                ans.append(0)
            else:
              for j in Chip:
                    sumOf += Chip[j]

              for j in ChipKey:
                    tmp = Chip[j]
                    tmp = tmp / sumOf
                    total += tmp
                    if total > N:
                        ans.append(j)
                        break

        # import matplotlib.pyplot as plt
        # plt.plot(date[len(date) - 1000:-1], ans[len(date) - 1000:-1])
        # plt.show()
        return pd.Series(ans).values
  
  
    def cost_default(self,N):
        self.calcuChip(flag=1, AC=1)
        date = self.trade_date

        N = N / 100  # 转换成百分比
        ans = []
        for i in self.ChipList:  # 我的ChipList本身就是有顺序的
            Chip = self.ChipList[i]
            ChipKey = sorted(Chip.keys())  # 排序
            total = 0  # 当前比例
            sumOf = 0  # 所有筹码的总和
            if len(Chip) ==0:
                ans.append(0)
            else:
              for j in Chip:
                    sumOf += Chip[j]

              for j in ChipKey:
                    tmp = Chip[j]
                    tmp = tmp / sumOf
                    total += tmp
                    if total > N:
                        ans.append(j)
                        break
        # import matplotlib.pyplot as plt
        # plt.plot(date[len(date) - 1000:-1], ans[len(date) - 1000:-1])
        # plt.show()
        return pd.Series(ans).values


def virtual_vol(vol_today):
    from datetime import datetime
    now = datetime.now()
    year = now.year
    month = now.month
    day = now.day
    hour = now.hour
    minute = now.minute
    kaipan = datetime(year=year, month=month, day=day, hour=9, minute=30, second=0)
    shangwu = datetime(year=year, month=month, day=day, hour=11, minute=30, second=0)
    zhongwu = datetime(year=year, month=month, day=day, hour=13, minute=0, second=0)

    xiawu = datetime(year=year, month=month, day=day, hour=15, minute=0, second=0)
    virtual_vol = vol_today
    if now < shangwu:
        shangwumin = (now - kaipan).total_seconds() // 60

        virtual_vol = (vol_today / shangwumin) * 200

    elif now > shangwu and now < zhongwu:
        zhongwu_min = 120
        virtual_vol = (vol_today / zhongwu_min) * 220

    elif now > zhongwu and now < xiawu:
        xiawumin = 120 + (now - zhongwu).total_seconds() // 60
        virtual_vol = (vol_today / xiawumin) * 230
    else:
        panhoumin = 120 + 120
        virtual_vol = vol_today

    return virtual_vol